#!/bin/bash

# Please change the following envrioment variables
# base on the cluster configuration
export OMP_NUM_THREADS=4
export MUSA_VISIBLE_DEVICES='0,1,2,3,4,5,6,7'
export MUSA_KERNEL_TIMEOUT=3200000
export MUSA_BLOCK_SCHEDULE_MODE=1
export ACCELERATOR_BACKEND="musa"
export MCCL_PROTOS=2
export MUSA_PRINT_ENV=1
export MCCL_CHECK_POINTERS=0
# export MCCL_PEER_ACCESS_IPC_FLAG=5
# export MCCL_PEER_ACCESS_FLAG=4
export CUDA_DEVICE_MAX_CONNECTIONS=1
export MCCL_IB_GID_INDEX=3
export MCCL_ALGOS=1
export MCCL_IB_HCA='mlx5_10:1,mlx5_11:1,mlx5_12:1,mlx5_13:1,mlx5_14:1,mlx5_15:1,mlx5_16:1,mlx5_17:1'

# export LD_LIBRARY_PATH=/usr/mpi/gcc/openmpi-4.1.5rc2/lib:/home/apex/amp_C/lib:$LD_LIBRARY_PATH:/usr/lib/x86_64-linux-gnu/musa/
export MCCL_DEBUG=INFO
#export MUSA_LOG=0xffff
export MS_DIST_MEMORY_PATH=/home/dist
# export TORCH_CPP_LOG_LEVEL=INFO
# export TORCH_DISTRIBUTED_DEBUG=INFO

FILEDIR=$(cd $(dirname $0);pwd)
PROJ_HOME=$FILEDIR/..
export MUSA_ERROR_DUMP_PATH=$PROJ_HOME/../
WORK_HOME=$PWD/..
PATCH_HOME="$PWD"/../../..

############################################################## dev install flash-atten #########################################
pushd $PATCH_HOME/../flash-attention
FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE python setup.py develop
popd
############################################################## dev install flash-atten #########################################

MEGATRON_PATH=${PATCH_HOME}/../Megatron-LM
MEGATRON_PATH=/home/megatron_0.9/Megatron-LM/
export PYTHONPATH=${MEGATRON_PATH}:${PATCH_HOME}:$PYTHONPATH
if [ ! -d "${MEGATRON_PATH}/build" ]; then
    cd "${MEGATRON_PATH}"
    python setup.py build_ext --inplace
    cd -
fi
# TP_SIZE PP_SIZE WORLD_SIZE MICRO_BATCH_SIZE NUM_MICROBATCHES
WORLD_SIZE=$MLFLOW_WORKER_TOTAL_GPUNUM
TP_SIZE=$MS_TP_SIZE
PP_SIZE=$MS_PP_SIZE
MICRO_BATCH_SIZE=$MS_MICRO_BATCH_SIZE
NUM_MICROBATCHES=$MS_NUM_MICROBATCHES
(( DP_SIZE = $WORLD_SIZE / ($TP_SIZE * $PP_SIZE) ))
echo 'DP SIZE is: '$DP_SIZE
(( GLOBAL_BATCH_SIZE = $MICRO_BATCH_SIZE * $NUM_MICROBATCHES * $DP_SIZE ))
echo 'GLOBAL BATCH SIZE is: '$GLOBAL_BATCH_SIZE
DATA_CACHE_PATH=$WORK_HOME/datasets/cache
EXPNAME=$MLFLOW_EXPERIMENT_NAME
CHECKPOINT_PATH=$WORK_HOME/checkpoints/$EXPNAME
mkdir -p $CHECKPOINT_PATH
DATA_DIR=$MODEL_STUDIO_DATASET_PATH
DATA_FILE_PREFIX=$(ls $DATA_DIR | grep ".idx" | head -n 1 | sed "s/.idx//")
DATA_PATH=$DATA_DIR/$DATA_FILE_PREFIX
TOKENIZED_MODEL=/home/dist/llama3/llama3_tokenizer
LOG_PATH=$WORK_HOME/logs/$EXPNAME
mkdir -p $LOG_PATH
TB_PATH=$WORK_HOME/tboard/$EXPNAME
mkdir -p $TB_PATH
WB_PATH=$WORK_HOME/wandb/$EXPNAME
mkdir -p $WB_PATH


export MASTER_PORT=12361
echo "node rank: ${NODE_RANK}"
# export MCCL_TOPO_DUMP_FILE=$PROJ_HOME/output/$EXPNAME/noderank${NODE_RANK}/nccl_topo.xml
# export MCCL_TOPO_FILE=$MCCL_TOPO_DUMP_FILE

DISTRIBUTED_ARGS=(
    --max-restarts=3 \
    --nproc_per_node $MLFLOW_WORKER_GPUNUM \
    --nnodes $NODE_NUM \
    --log_dir $WORK_HOME/output/$EXPNAME/noderank${NODE_RANK} \
    --redirects 3
)

    # --accelerator=mthreads.com/gpu \
    # --network-check \
#   --redirects 3

MODEL_ARGS=(
    --num-layers 32
    --hidden-size 4096 
    --ffn-hidden-size 14336
    --num-attention-heads 32 
    --group-query-attention 
    --num-query-groups 8
    --seq-length 4096 
    --max-position-embeddings 4096 
    --norm-epsilon 1e-5 
    --attention-dropout 0.0 
    --hidden-dropout 0.0 
    --disable-bias-linear 
    --position-embedding-type rope 
    --no-position-embedding 
    --swiglu 
    --normalization RMSNorm
    --untie-embeddings-and-output-weights
)

# 244140625 1T
TRAINING_ARGS=(
    --seed 42 
    --micro-batch-size $MICRO_BATCH_SIZE 
    --global-batch-size $GLOBAL_BATCH_SIZE  
    --train-samples 24414062 
    --init-method-std 0.008
    --use-mcore-models 
    --no-bias-dropout-fusion
    --no-bias-swiglu-fusion
    --use-distributed-optimizer 
    --use-flash-attn 
    --sequence-parallel 
    --recompute-granularity full 
    --recompute-method block 
    --recompute-num-layers 0 
    --distributed-backend nccl 
    --transformer-impl local
)

REGULARIZATION_ARGS=(
    --weight-decay 0.1 
    --adam-beta1 0.9 
    --adam-beta2 0.95 
    --clip-grad 1.0 
)

WARMUP_STEPS=2000
WARMUP_SAMPLES=$((WARMUP_STEPS * GLOBAL_BATCH_SIZE))

LEARNING_RATE_ARGS=(
    --lr 1.5e-5 
    --lr-decay-style cosine 
    --lr-warmup-samples ${WARMUP_SAMPLES} 
    --min-lr 1.5e-6 
    --initial-loss-scale 65536 
    --min-loss-scale 1.0 
)

MODEL_PARALLEL_ARGS=(
	--tensor-model-parallel-size $TP_SIZE  
	--pipeline-model-parallel-size $PP_SIZE 
    --decoder-last-pipeline-num-layers 14
)

MIXED_PRECISION_ARGS=(
    --bf16 
    --attention-softmax-in-fp32 
    --no-masked-softmax-fusion
    --accumulate-allreduce-grads-in-fp32
)

DATA_ARGS=(
    --data-path $DATA_PATH
    --data-cache-path $DATA_CACHE_PATH
    --tokenizer-type HuggingFaceTokenizer
    --tokenizer-model ${TOKENIZED_MODEL}
    --split 1
)

# DATA_ARGS=(
#     --data-path $DATA_PATH 
#     --vocab-file $VOCAB_FILE 
#     --merge-file $MERGE_FILE 
#     --split 949,50,1
# )



EVAL_AND_LOGGING_ARGS=(
    --log-interval 1
    --save-interval 200000 
    --eval-interval 1000 
    --save $CHECKPOINT_PATH
    --load $CHECKPOINT_PATH 
    --eval-iters 0
    --tensorboard-dir $TB_PATH 
)
#ASYNC_CKPT_LOCAL_DIR=/checkpoint-cache #${MODEL_STUDIO_CHECKPOINT_CACHE_DIR}  # 本地盘存储目录
ASYNC_CKPT_MEM_DIR="/dev/shm" # 内存

ASYNC_CKPT_ARGS="
   --enable_async_ckpt \
   --mem_dir $ASYNC_CKPT_MEM_DIR
"

export ENABLE_ASYNC_CKPT=0

LOGGING_ARGS="
    --log-interval 1 \
    --log-params-norm
"
# \
#     --log-params-norm

TORCHRUN=dlrover-run

cmd="$TORCHRUN ${DISTRIBUTED_ARGS[@]} $WORK_HOME/pretrain_gpt.py \
        ${MODEL_ARGS[@]} \
        ${TRAINING_ARGS[@]} \
        ${REGULARIZATION_ARGS[@]}
        ${LEARNING_RATE_ARGS[@]} \
        ${MODEL_PARALLEL_ARGS[@]} \
        ${MIXED_PRECISION_ARGS[@]} \
        ${DATA_ARGS[@]} \
        ${EVAL_AND_LOGGING_ARGS[@]}
"
echo $cmd
eval $cmd
